<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Image classification

<Youtube id="tjAIM7BOYhw"/>

Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that represent an image. There are many uses for image classification, like detecting damage after a disaster, monitoring crop health, or helping screen medical images for signs of disease.

This guide will show you how to fine-tune [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) on the [Food-101](https://huggingface.co/datasets/food101) dataset to classify a food item in an image.

<Tip>

See the image classification [task page](https://huggingface.co/tasks/image-classification) for more information about its associated models, datasets, and metrics.

</Tip>

## Load Food-101 dataset

The full Food-101 dataset is large, so load only the first 5000 images of its training split from the 🤗 Datasets library:

```py
>>> from datasets import load_dataset

>>> food = load_dataset("food101", split="train[:5000]")
```

Split this subset into train and test sets, holding out 20% of the images for evaluation:

```py
>>> food = food.train_test_split(test_size=0.2)
```
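
To make the resulting split sizes concrete, here is a minimal pure-Python sketch of a shuffled 80/20 split over 5000 indices. It only illustrates the idea and is not how 🤗 Datasets implements `train_test_split`; the helper `split_indices` and the seed value are hypothetical.

```py
import random


def split_indices(num_examples, test_size=0.2, seed=42):
    """Shuffle the indices and hold out a `test_size` fraction (illustrative helper)."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    num_test = round(num_examples * test_size)
    # The first `num_test` shuffled indices become the test set, the rest the train set
    return indices[num_test:], indices[:num_test]


train_idx, test_idx = split_indices(5000, test_size=0.2)
print(len(train_idx), len(test_idx))  # 4000 1000
```

If you need the same split on every run, `train_test_split` also accepts a `seed` argument.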

Then take a look at an example:

```py
>>> food["train"][0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
 'label': 79}
```

The `image` field contains a PIL image, and each `label` is an integer that represents a class. Create two dictionaries: one that maps a label name to its number, and one that maps the number back to the name. Both store the number as a string. These mappings help the model recover the label name from the label number:

```py
>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label
```

Now you can convert a label number to its label name. Because the keys are strings, convert the number with `str` first:

```py
>>> id2label[str(79)]
'prime_rib'
```

Each food class - or label - corresponds to a number; `79` indicates a prime rib in the example above.
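
As a quick sanity check, the sketch below builds the same two mappings for a shortened, three-name label list and confirms that a name survives the round trip. The short list is an assumption made for illustration; the real list comes from `food["train"].features["label"].names`.

```py
# Hypothetical shortened label list, used only for illustration
labels = ["apple_pie", "baby_back_ribs", "baklava"]

label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# Round trip: label name -> label number (as a string) -> label name
round_trip = id2label[label2id["baklava"]]
print(label2id["baklava"], round_trip)  # 2 baklava
```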

## Preprocess

Load the ViT feature extractor to process the image into a tensor:

```py
>>> from transformers import AutoFeatureExtractor

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
```

Apply several image transformations to the dataset to make the model more robust against overfitting. Here you'll use torchvision's [`transforms`](https://pytorch.org/vision/stable/transforms.html) module. Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:

```py
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
```
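
The arithmetic behind `ToTensor` and `Normalize` can be sketched for a single channel value in plain Python. `ToTensor` scales 8-bit pixel values into the range [0, 1], and `Normalize` then computes `(value - mean) / std`. The mean and standard deviation of `0.5` below are assumed for illustration; read the actual values from `feature_extractor.image_mean` and `feature_extractor.image_std`. The helper names are hypothetical.

```py
def to_tensor_value(pixel):
    """Mimic ToTensor's scaling of an 8-bit pixel value into [0.0, 1.0]."""
    return pixel / 255.0


def normalize_value(value, mean, std):
    """Mimic Normalize for one channel value: subtract the mean, divide by the std."""
    return (value - mean) / std


mean, std = 0.5, 0.5  # assumed values, not read from the feature extractor

darkest = normalize_value(to_tensor_value(0), mean, std)
brightest = normalize_value(to_tensor_value(255), mean, std)
print(darkest, brightest)  # -1.0 1.0
```

With these assumed statistics, the pixel range 0 to 255 maps onto -1 to 1.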

Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:

```py
>>> def transforms(examples):
...     examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
...     del examples["image"]
...     return examples
```

Use 🤗 Datasets' [`with_transform`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?#datasets.Dataset.with_transform) method to apply the transforms over the entire dataset. The transforms are applied on the fly, each time you load an element of the dataset:

```py
>>> food = food.with_transform(transforms)
```
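
To illustrate what "on the fly" means, here is a toy sketch in which the transform runs only when an element is accessed and the stored rows stay untouched. `LazyDataset` is a hypothetical stand-in, not the 🤗 Datasets implementation, and plain strings stand in for images so the example runs without extra dependencies.

```py
class LazyDataset:
    """Hypothetical stand-in for a dataset with an on-the-fly transform."""

    def __init__(self, rows, transform):
        self.rows = rows
        self.transform = transform
        self.calls = 0  # counts how often the transform has run

    def __getitem__(self, index):
        self.calls += 1
        # Wrap the row as a batch of size one, since the transform expects lists
        batch = {key: [value] for key, value in self.rows[index].items()}
        transformed = self.transform(batch)
        return {key: values[0] for key, values in transformed.items()}


def transforms(examples):
    # Stand-in for the image transforms: uppercase the placeholder string
    examples["pixel_values"] = [img.upper() for img in examples["image"]]
    del examples["image"]
    return examples


rows = [{"image": "img0", "label": 79}, {"image": "img1", "label": 12}]
toy_food = LazyDataset(rows, transforms)

print(toy_food.calls)  # 0, nothing has been transformed yet
item = toy_food[0]
print(item)  # {'label': 79, 'pixel_values': 'IMG0'}
print(toy_food.calls)  # 1
```

Because the transforms include a random crop, applying them at access time also means the same image can be cropped differently each time it is loaded.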

Use [`DefaultDataCollator`] to create a batch of examples. Unlike other data collators in 🤗 Transformers, the `DefaultDataCollator` does not apply additional preprocessing such as padding:

```py
>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()
```
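
Conceptually, collating turns a list of per-example dictionaries into one dictionary of batched values. The sketch below shows that idea with plain lists; `collate` is a hypothetical helper, and the real collator additionally converts the values into framework tensors. Renaming `label` to `labels` mirrors what the default collator does so the key matches the model's `labels` argument.

```py
def collate(features):
    """Group per-example dicts into one batch dict (illustrative only)."""
    batch = {}
    for key in features[0]:
        # Rename `label` to `labels`, the argument name the model expects
        out_key = "labels" if key == "label" else key
        batch[out_key] = [example[key] for example in features]
    return batch


# Tiny made-up examples; real `pixel_values` are image tensors
features = [
    {"pixel_values": [0.1, 0.2], "label": 79},
    {"pixel_values": [0.3, 0.4], "label": 12},
]
batch = collate(features)
print(batch["labels"])  # [79, 12]
print(len(batch["pixel_values"]))  # 2
```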

## Train

<frameworkcontent>
<pt>
Load ViT with [`AutoModelForImageClassification`]. Specify the number of labels, and pass the model the mapping between label number and label class:

```py
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

>>> model = AutoModelForImageClassification.from_pretrained(
...     "google/vit-base-patch16-224-in21k",
...     num_labels=len(labels),
...     id2label=id2label,
...     label2id=label2id,
... )
```

<Tip>

If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#finetune-with-trainer)!

</Tip>

At this point, only three steps remain:

1. Define your training hyperparameters in [`TrainingArguments`]. By default, the [`Trainer`] removes columns that the model doesn't use, which would drop the `image` column. Without the `image` column, the transforms can't create `pixel_values`, so set `remove_unused_columns=False` to prevent this behavior!
2. Pass the training arguments to [`Trainer`] along with the model, datasets, feature extractor (through the `tokenizer` argument), and data collator.
3. Call [`~Trainer.train`] to fine-tune your model.

```py
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     per_device_train_batch_size=16,
...     evaluation_strategy="steps",
...     num_train_epochs=4,
...     fp16=True,
...     save_steps=100,
...     eval_steps=100,
...     logging_steps=10,
...     learning_rate=2e-4,
...     save_total_limit=2,
...     remove_unused_columns=False,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=food["train"],
...     eval_dataset=food["test"],
...     tokenizer=feature_extractor,
... )

>>> trainer.train()
```
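
To get a feel for how long this run is, the arithmetic below estimates the number of optimization steps and evaluations implied by the settings above. It is a rough sketch that assumes a single device, no gradient accumulation, and the 4000-image train split produced earlier; the variable names simply mirror the `TrainingArguments` fields.

```py
import math

num_train_examples = 4000  # 80% of the 5000 loaded images
per_device_train_batch_size = 16
num_train_epochs = 4
eval_steps = 100

# One optimization step per batch (assumes one device, no gradient accumulation)
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_evaluations = total_steps // eval_steps

print(steps_per_epoch, total_steps, num_evaluations)  # 250 1000 10
```

Under these assumptions, `save_steps=100` produces a checkpoint at the same interval, and `save_total_limit=2` keeps only two of them on disk.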
</pt>
</frameworkcontent>

<Tip>

For a more in-depth example of how to fine-tune a model for image classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).

</Tip>